import pysam
import numpy as np
import pdb
import os
from tqdm import tqdm
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib
import re
from collections import defaultdict
from scipy.sparse import csr_matrix
import sqlite3
import vcf


# Directory that is listed once at import time. The commented-out GC-content
# pass further down opens files from this directory with pysam.
bams_path = "/zion/1kg_exome_bams/"
# Snapshot of the directory's entry names; find_file() searches it by prefix.
# os.listdir raises if the directory does not exist on the current machine.
files = os.listdir(bams_path)
def find_file(text):
    """Look up entries of the module-level ``files`` list by name prefix.

    Args:
        text: Prefix that a name must start with to be selected.

    Returns:
        A list with every matching name (in ``files`` order), or ``None``
        when no name starts with ``text``.
    """
    matches = list(filter(lambda name: name.startswith(text), files))
    # An empty list is falsy, so "no match" collapses to None.
    return matches or None



# Saved arrays from ../outputs/. The same slice is applied to every array, so
# each one has its first 1500 entries dropped.
# NOTE(review): the reason for skipping the first 1500 entries is not visible
# in this file; presumably all arrays are index-aligned per sample -- confirm.
# Of these, only data_list_embeddings is consumed by the live code below (as
# the t-SNE input); chrs/indexes/names are referenced by commented-out passes.
_tail = slice(1500, None)

data_list_readdepths = np.load('../outputs/data_list_readdepths_test_weightedcrossent_embeddingout.npy')[_tail]
data_list_indexes = np.load('../outputs/data_list_indexes_test_weightedcrossent_embeddingout.npy')[_tail]
data_list_cnvnator_preds = np.load('../outputs/data_list_cnvnator_preds_test_weightedcrossent_embeddingout.npy')[_tail]
data_list_xhmm_preds = np.load('../outputs/data_list_xhmm_preds_test_weightedcrossent_embeddingout.npy')[_tail]
data_list_chrs = np.load('../outputs/data_list_chrs_test_weightedcrossent_embeddingout.npy')[_tail]
data_list_names = np.load('../outputs/data_list_names_test_weightedcrossent_embeddingout.npy')[_tail]
data_list_embeddings = np.load('../outputs/embeddings_rev12.npy')[_tail]
data_list_polished = np.load('../outputs/polished_rev12.npy')[_tail]



# mappability_file = open("/home/furkan/deepXCNV/XHMM/scripts/mappability.wig").read().splitlines()


# data_list_gcs = []

# for index in tqdm(range(len(data_list_names))):
#     filename = find_file(data_list_names[index])

#     if len(filename[0]) < len(filename[1]):
#         filename = filename[0]
#     else:
#         filename = filename[1]

#     print(index)

#     try:
#         bam = pysam.AlignmentFile(bams_path+filename)
#     except:
#         gc_percent = 0
#         data_list_gcs.append(gc_percent)
#         continue

#     read_data = bam.fetch(data_list_chrs[index], data_list_indexes[index][0], data_list_indexes[index][1])
#     total_bases = 0
#     gc_bases = 0
#     for read in read_data:
#         seq = read.query_sequence
#         total_bases += len(seq)
#         gc_bases += len([x for x in seq if x =='C' or x == 'G'])
#         gc_percent = float(gc_bases)/total_bases * 100

#     data_list_gcs.append(gc_percent)

# def translate_wig_to_vcf(filename, out_filename):
#     with tqdm(total=os.path.getsize(filename), desc='Translating wig file') as pbar, \
#         open(filename, 'r') as in_f, \
#         open(out_filename, 'w') as out_f:

#         out_f.write("""##fileformat=VCFv4.3
# ##fileDate=20090805
# ##source=myImputationProgramV3.1
# ##phasing=partial
# #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO
# """)

#         chr, pos = '', 0
#         for line in in_f:
#             pbar.update(len(line))

#             if line.startswith('fixedStep'):
#                 chr, pos = re.findall(r'chrom=(.+) start=(\d+) step=1', line)[0]
#                 pos = int(pos)
#             else:
#                 out_f.write(f"{chr}\t{pos}\t.\t.\t.\t{line.strip()}\t.\t.\n")   
#                 pos += 1
            
# def translate_wig_to_sqlite(filename, out_filename):
#     with tqdm(total=os.path.getsize(filename), desc='Translating wig file') as pbar, \
#         open(filename, 'r') as f, \
#         sqlite3.connect(out_filename, isolation_level=None) as con:
#         cur = con.cursor()
#         cur.execute('CREATE TABLE data (chrom TEXT NOT NULL, pos INTEGER NOT NULL, value REAL NOT NULL, PRIMARY KEY (chrom, pos))')

#         chr, pos = '', 0
#         for line in f:
#             pbar.update(len(line))
#             if line.startswith('fixedStep'):
#                 chr, pos = re.findall(r'chrom=(.+) start=(\d+) step=1', line)[0]
#                 pos = int(pos)
#             else:
#                 cur.execute(f"INSERT INTO data VALUES ('{chr}', {pos}, {line.strip()})") 
#                 pos += 1

# # translate_wig_to_vcf('/home/furkan/deepXCNV/XHMM/scripts/mappability.wig', '/home/furkan/deepXCNV/XHMM/scripts/mappability.vcf')
# # translate_wig_to_sqlite('/home/furkan/deepXCNV/XHMM/scripts/mappability.wig', '/home/furkan/deepXCNV/XHMM/scripts/mappability.db')

# data_list_mapp = []
# mappability_reader = vcf.Reader(filename='/home/furkan/deepXCNV/XHMM/scripts/mappability.vcf.gz')

# def compute_value(chr, start, end):
#     records = mappability_reader.fetch(chr, start, end)
#     values = [float(record.QUAL) for record in records]

#     return np.mean(values)


# for index in tqdm(range(len(data_list_names))):
    
#     chr = data_list_chrs[index] #'chr17'
#     start = data_list_indexes[index][0]
#     end = data_list_indexes[index][1]

#     data_list_mapp.append(compute_value(chr, start, end))

# Values used to color the t-SNE scatter plot below, read from the current
# working directory.
# NOTE(review): presumably produced by the commented-out mappability pass
# above (no np.save for it is visible). Unlike the ../outputs arrays it is not
# sliced here -- confirm it is index-aligned with data_list_embeddings.
data_list_mapp = np.load('data_list_mapp.npy')
# pdb.set_trace()

# np.save("data_list_gcs.npy", np.asarray(data_list_gcs))

# GC values read back from the file name used by the disabled np.save above.
# The code below divides them by 100, so they are treated as percentages.
data_list_gcs_loaded = np.load("data_list_gcs.npy")

# Debugger stop disabled, following the convention used a few lines up: left
# live it halts the script before the t-SNE/plot runs and aborts with BdbQuit
# when stdin is not interactive. Uncomment to inspect the loaded arrays.
# pdb.set_trace()


# Project the embeddings to two dimensions with t-SNE.
# scikit-learn renamed TSNE's ``n_iter`` argument to ``max_iter`` and later
# removed the old spelling, so try the new name first and fall back to the old
# one on releases that do not know ``max_iter`` (unknown keyword -> TypeError).
tsne_kwargs = dict(n_components=2, perplexity=20, learning_rate=200)
try:
    tsne = TSNE(max_iter=3000, **tsne_kwargs)
except TypeError:
    tsne = TSNE(n_iter=3000, **tsne_kwargs)
embedding = tsne.fit_transform(data_list_embeddings)


# Colors derived from the GC values (percent scaled to the 0-1 colormap range).
# They are not used by the mappability figure below; kept so the GC variant of
# the plot can be restored by passing them to scatter.
# plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists in
# recent Matplotlib releases, while pyplot.get_cmap works on old and new ones.
cmap = plt.get_cmap('plasma')
colors = cmap(data_list_gcs_loaded/100)

# Scatter of the two t-SNE components, colored by data_list_mapp.
figure = plt.figure()
ax = figure.add_subplot(111)
ax.set_title("TSNE applied to CNV read depth embeddings of DECoNT")
sc = ax.scatter(embedding[:,0], embedding[:,1], c=data_list_mapp, cmap='plasma')
ax.set_xlabel("Component 1")
ax.set_ylabel("Component 2")
# Attach the colorbar to this figure/axes explicitly instead of relying on
# pyplot's notion of the "current" figure.
cbar = figure.colorbar(sc, ax=ax)
cbar.ax.get_yaxis().labelpad = 15  # leave room for the rotated label
cbar.ax.set_ylabel('Mappability Score', rotation=270)
figure.savefig("mappability_tsne.png")

# Debugger stop disabled so the script exits cleanly after writing the PNG;
# uncomment to inspect `embedding` interactively.
# pdb.set_trace()